
[SPARK-13568] [ML] Create feature transformer to impute missing values #11601

Closed
wants to merge 54 commits

Conversation

hhbyyh
Contributor

@hhbyyh hhbyyh commented Mar 9, 2016

What changes were proposed in this pull request?

jira: https://issues.apache.org/jira/browse/SPARK-13568
It is quite common to encounter missing values in data sets. It would be useful to implement a Transformer that can impute missing data points, similar to e.g. Imputer in scikit-learn.
Initially, options for imputation could include mean, median and most frequent, but we could add various other approaches, where possible existing DataFrame code can be used (e.g. for approximate quantiles etc).

Currently this PR supports imputation for Double and Vector (null and NaN in Vector).

How was this patch tested?

New unit tests and manual tests.

@SparkQA

SparkQA commented Mar 9, 2016

Test build #52734 has finished for PR 11601 at commit 1b39668.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

val strategy: Param[String] = new Param(this, "strategy", "strategy for imputation. " +
  "If mean, then replace missing values using the mean along the axis. " +
  "If median, then replace missing values using the median along the axis. " +
  "If most, then replace missing using the most frequent value along the axis.")
Contributor

Could you add a param validation function, since there are a limited number of valid strategies? You can add an attribute like val supportedMissingValueStrategies = Set("mean", "median", "most") to the Imputer companion object, as is done here
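The suggested validation could be sketched as below. The names and the use of a plain require are assumptions for illustration, not the merged PR code; in Spark ML this kind of check is usually attached to the Param itself via ParamValidators.inArray.

```scala
// Sketch of the suggested strategy validation (hypothetical shape; the
// merged code may differ). In Spark ML the idiomatic form would be:
//   new Param[String](this, "strategy", "imputation strategy",
//     ParamValidators.inArray(Imputer.supportedMissingValueStrategies.toArray))
object Imputer {
  // Companion-object attribute holding the valid strategy names
  val supportedMissingValueStrategies: Set[String] = Set("mean", "median", "most")
}

// Fail fast on an unsupported strategy name
def validateStrategy(strategy: String): Unit =
  require(Imputer.supportedMissingValueStrategies.contains(strategy),
    s"Unsupported imputation strategy: $strategy")

validateStrategy("mean") // passes silently
```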

Contributor Author

I added the validation to validateParameter (which should be moved, since it's deprecated). Thanks for the suggestion. I'll add them.

@sethah
Contributor

sethah commented Mar 9, 2016

Looking at the JIRAs, it is unclear whether any concrete decisions were made regarding handling Vectors and how NaN values should be handled in colStats. Is there any update?

@hhbyyh
Contributor Author

hhbyyh commented Mar 10, 2016

I prefer to keep Statistics.colStats(rdd) unchanged for now. As the unit tests in this PR suggest, we can cover Double and Vector for now.

@hhbyyh
Contributor Author

hhbyyh commented Mar 10, 2016

@sethah @MLnick Thanks for helping with review. I made a pass according to the review comments and added some more code comments.

@SparkQA

SparkQA commented Mar 10, 2016

Test build #52842 has finished for PR 11601 at commit 4e45f81.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

val colStatistics = $(strategy) match {
case "mean" =>
filteredDF.selectExpr(s"avg($colName)").first().getDouble(0)
case "median" =>
Contributor

I think we should favour using the new approxQuantile sql stat function here rather than computing exactly.
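The approxQuantile stat function the reviewer refers to computes quantiles with a bounded relative error rather than an exact sort. A minimal sketch of using it for the median (the local session and toy data are assumptions for illustration):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]")
  .appName("approx-median-sketch").getOrCreate()
import spark.implicits._

val df = Seq(1.0, 2.0, 3.0, 4.0, 100.0).toDF("value")

// probabilities = Array(0.5) asks for the median; 0.001 is the relative error.
// Returns one value per requested probability.
val Array(median) = df.stat.approxQuantile("value", Array(0.5), 0.001)

spark.stop()
```

With such a tight relative error on a tiny dataset, the result is the exact median (3.0 here); on large data the approximation is what makes it cheap.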

@SparkQA

SparkQA commented Mar 23, 2016

Test build #53923 has finished for PR 11601 at commit 1b36deb.

  • This patch fails Scala style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Mar 23, 2016

Test build #53931 has finished for PR 11601 at commit 72d104d.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Feb 22, 2017

Test build #73268 has started for PR 11601 at commit e86d919.

@hhbyyh
Contributor Author

hhbyyh commented Feb 22, 2017

/** @group getParam */
def getMissingValue: Double = $(missingValue)

/**
Contributor

Fix comment indentation here.

* All Null values in the input column are treated as missing, and so are also imputed.
*/
@Experimental
class Imputer @Since("2.1.0")(override val uid: String)
Contributor

All @Since annotations -> 2.2.0

/**
* Params for [[Imputer]] and [[ImputerModel]].
*/
private[feature] trait ImputerParams extends Params with HasInputCols with HasOutputCol {
Contributor

We don't use HasOutputCol anymore, correct?

Contributor Author

Sure, however I didn't get your first comment. Do you mean we should remove the import?

object Imputer extends DefaultParamsReadable[Imputer] {

/** Set of strategy names that Imputer currently supports. */
private[ml] val supportedStrategyNames = Set("mean", "median")
Contributor

Could we factor out the mean and median names into private[ml] vals, to be used instead of the raw strings throughout?

Contributor Author

Yes, that's better.

case "mean" => filtered.select(avg(inputCol)).first().getDouble(0)
case "median" => filtered.stat.approxQuantile(inputCol, Array(0.5), 0.001)(0)
}
surrogate.asInstanceOf[Double]
Contributor

is the asInstanceOf[Double] necessary here?

Contributor Author

@hhbyyh commented Mar 3, 2017

no, will remove it.

test("ImputerModel read/write") {
val spark = this.spark
import spark.implicits._
val surrogateDF = Seq(1.234).toDF("myInputCol")
Contributor

This should be "surrogate" col name - though I see we don't actually use it in load or transform

Contributor Author

this happens to be the correct column name for now.

Contributor

OK - we should add a test here to check that the column names of instance and newInstance match up. (The check below is just for the actual values of the surrogate, correct?)

var outputDF = dataset
val surrogates = surrogateDF.head().getSeq[Double](0)

$(inputCols).indices.foreach { i =>
Contributor

You could do $(inputCols).zip($(outputCols)).zip(surrogates).map { case ((inputCol, outputCol), icSurrogate) => ...
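The zip idiom from the comment can be sketched on plain collections (column names and surrogate values below are hypothetical, and the string result stands in for the real per-column imputation):

```scala
val inputCols  = Array("value1", "value2")
val outputCols = Array("out1", "out2")
val surrogates = Seq(1.2, 3.4)

// Pair each input column with its output column and its surrogate,
// then pattern-match the nested tuples, as the comment suggests.
val imputePlan = inputCols.zip(outputCols).zip(surrogates).map {
  case ((inputCol, outputCol), icSurrogate) =>
    s"replace missing in $inputCol with $icSurrogate, write to $outputCol"
}
```

This avoids indexing into three parallel arrays by position, which is what the `indices.foreach { i => ... }` form in the snippet does.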

val localOutputCols = $(outputCols)
var outputSchema = schema

$(inputCols).indices.foreach { i =>
Contributor

Can do $(inputCols).zip($(outputCols)).foreach { case (inputCol, outputCol) => ...

}
val surrogate = $(strategy) match {
case "mean" => filtered.select(avg(inputCol)).first().getDouble(0)
case "median" => filtered.stat.approxQuantile(inputCol, Array(0.5), 0.001)(0)
Contributor

.head

* Model fitted by [[Imputer]].
*
* @param surrogateDF Value by which missing values in the input columns will be replaced. This
* is stored using DataFrame with input column names and the corresponding surrogates.
Contributor

This is misleading - you're just storing the array of surrogates... did you mean something different? Otherwise the comment must be changed.

Contributor

It sounds like you had the idea of storing the surrogates something like:

+------+---------+
|column|surrogate|
+------+---------+
|  col1|      1.2|
|  col2|      3.4|
|  col3|      5.4|
+------+---------+

?

Contributor Author

I refactored it a little for better extensibility.

inputCol1 inputCol2
surrogate1 surrogate2
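The layout described here - a single row whose column names are the input columns and whose values are the surrogates - can be sketched like this (the column names and values are hypothetical, not taken from the PR's tests):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]")
  .appName("surrogate-layout").getOrCreate()
import spark.implicits._

// One row: each input column name maps to its surrogate value.
val surrogateDF = Seq((1.2, 3.4)).toDF("inputCol1", "inputCol2")

val cols = surrogateDF.columns.toSeq   // the input column names
val row  = surrogateDF.head()          // the surrogates, one per column

spark.stop()
```

Adding a new statistic later only means adding a row, which is the extensibility argument the author makes.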

@MLnick
Contributor

MLnick commented Mar 2, 2017

jenkins retest this please

@SparkQA

SparkQA commented Mar 2, 2017

Test build #73753 has finished for PR 11601 at commit e86d919.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@hhbyyh
Contributor Author

hhbyyh commented Mar 2, 2017

Thanks a lot for making a pass @MLnick. The last update mainly focused on the interface and behavior change. I'll make a pass and also address your comments.

@hhbyyh
Contributor Author

hhbyyh commented Mar 3, 2017

Hi @MLnick I changed the surrogateDF format for better extensibility in the last update and added unit tests for multi-column support. Let me know if I miss anything.

inputCol1 inputCol2
surrogate1 surrogate2

@SparkQA

SparkQA commented Mar 3, 2017

Test build #73868 has finished for PR 11601 at commit 41d91b9.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
  • class Imputer @Since("2.2.0")(override val uid: String)

Contributor

@MLnick left a comment

Made a pass. A few minor comments.

* The imputation strategy.
* If "mean", then replace missing values using the mean value of the feature.
* If "median", then replace missing values using the approximate median value of the
* feature (relative error less than 0.001).
Contributor

I think we should remove the "(relative error less than 0.001)" part.

This can be moved to the overall ScalaDoc for Imputer at L95.

/**
* :: Experimental ::
* Imputation estimator for completing missing values, either using the mean or the median
* of the column in which the missing values are located. The input column should be of
Contributor

As mentioned above at https://github.com/apache/spark/pull/11601/files#r104403880, you can add the note about relative error here.

Something like "For computing median, approxQuantile is used with a relative error of X" (provide a ScalaDoc link to approxQuantile).

Contributor Author

I didn't add the link, as it may break Javadoc generation.

Contributor

Ah right - perhaps just mention using approxQuantile?

@Since("2.2.0")
def setMissingValue(value: Double): this.type = set(missingValue, value)

import org.apache.spark.ml.feature.Imputer._
Contributor

This import should probably be above with the others (or within fit)

}
val surrogate = $(strategy) match {
case Imputer.mean => filtered.select(avg(inputCol)).as[Double].first()
case Imputer.median => filtered.stat.approxQuantile(inputCol, Array(0.5), 0.001).head
Contributor

Not really sure about the relative error here - perhaps 0.01 is sufficient?

Contributor

Later perhaps we can even expose it as an expert param (but not for now)

Contributor Author

I tried it before. 0.01 and 0.001 actually take the same time, even for a large dataset. Agree we can make it a param later.

override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
var outputDF = dataset
val surrogates = surrogateDF.select($(inputCols).head, $(inputCols).tail: _*).head().toSeq
Contributor

Maybe this is slightly cleaner: surrogateDF.select($(inputCols).map(col): _*)
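The two select forms - the snippet's head/tail varargs and the suggested map over col - produce the same result. A sketch with a hypothetical two-column surrogateDF:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

val spark = SparkSession.builder().master("local[2]")
  .appName("select-cols-sketch").getOrCreate()
import spark.implicits._

val inputCols = Array("inputCol1", "inputCol2")
val surrogateDF = Seq((1.2, 3.4)).toDF(inputCols: _*)

// Original form: the String varargs overload needs head/tail splitting
val a = surrogateDF.select(inputCols.head, inputCols.tail: _*).head().toSeq

// Suggested form: map each name to a Column and pass one varargs list
val b = surrogateDF.select(inputCols.map(col): _*).head().toSeq

spark.stop()
```

The Column-based form avoids the head/tail split that the String-varargs overload of select forces.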

.setInputCols(Array("value1", "value2"))
.setOutputCols(Array("out1"))
.setStrategy(strategy)
intercept[IllegalArgumentException] {
Contributor

Also test for the thrown message here, and use withClue.
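Checking both the exception type and its message can be sketched in plain Scala as below; in the actual suite, ScalaTest's intercept and withClue wrap this same pattern. The validation function and message text here are hypothetical stand-ins for the Imputer's real check.

```scala
// Hypothetical validation under test: require throws IllegalArgumentException
def validateLengths(in: Array[String], out: Array[String]): Unit =
  require(in.length == out.length,
    "inputCols and outputCols must have the same length")

// Capture the exception so we can assert on its message, not just its type
val thrown =
  try { validateLengths(Array("value1", "value2"), Array("out1")); None }
  catch { case e: IllegalArgumentException => Some(e) }

assert(thrown.exists(_.getMessage.contains("must have the same length")))
```

Asserting on the message guards against the test passing because some unrelated IllegalArgumentException was thrown.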



}

object ImputerSuite{
Contributor

space before {

Seq("mean", "median").foreach { strategy =>
val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out"))
.setStrategy(strategy)
intercept[SparkException] {
Contributor

Check message here also.

)).toDF("id", "value1", "value2", "value3")
Seq("mean", "median").foreach { strategy =>
// inputCols and outCols length different
val imputer = new Imputer()
Contributor

You can also perhaps use withClue to put a message on the subtest / exception assertion (e.g. withClue("Imputer should fail if inputCols and outputCols have different lengths")).

@SparkQA

SparkQA commented Mar 6, 2017

Test build #74038 has finished for PR 11601 at commit e378db5.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@@ -99,7 +98,8 @@ private[feature] trait ImputerParams extends Params with HasInputCols {
* (SPARK-15041) and possibly creates incorrect values for a categorical feature.
*
* Note that the mean/median value is computed after filtering out missing values.
* All Null values in the input column are treated as missing, and so are also imputed.
* All Null values in the input column are treated as missing, and so are also imputed. For
* computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001.
Contributor

Ah I see it is here - nevermind

val ic = col(inputCol)
val filtered = dataset.select(ic.cast(DoubleType))
.filter(ic.isNotNull && ic =!= $(missingValue) && !ic.isNaN)
if(filtered.rdd.isEmpty()) {
Contributor

@MLnick commented Mar 8, 2017

I think we can do filtered.take(1).size == 0 which should be more efficient
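The emptiness check the reviewer suggests can be sketched as follows (toy data and the missingValue of -1.0 are assumptions for illustration; the real code operates on the fitted column):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]")
  .appName("empty-check-sketch").getOrCreate()
import spark.implicits._

val missingValue = -1.0
val df = Seq(Double.NaN, -1.0).toDF("value")

val ic = $"value"
val filtered = df.filter(ic.isNotNull && ic =!= missingValue && !ic.isNaN)

// take(1) fetches at most one row and stops, whereas
// filtered.rdd.isEmpty() first converts the Dataset to an RDD.
val allMissing = filtered.take(1).isEmpty

spark.stop()
```

Here every value is NaN or the missing-value sentinel, so the filter leaves nothing and the surrogate cannot be computed - exactly the case the SparkException below guards against.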

.filter(ic.isNotNull && ic =!= $(missingValue) && !ic.isNaN)
if(filtered.rdd.isEmpty()) {
throw new SparkException(s"surrogate cannot be computed. " +
s"All the values in $inputCol are Null, Nan or missingValue ($missingValue)")
Contributor

($missingValue) -> ${$(missingValue)}?

@MLnick
Contributor

MLnick commented Mar 8, 2017

Made a few last comments. LGTM.

cc @sethah @jkbradley I am going to merge this for 2.2. Let me know if you have any final comments.

@MLnick
Contributor

MLnick commented Mar 8, 2017

By the way, out of curiosity, I tested things out on a cluster (4x workers, 192 cores & 480GB RAM total), with 100 columns of 100 million doubles each, 1% NaN occurrence, reading from a Parquet file.

Not cached: fit takes about 1.5 seconds per column (150 secs total), while transform takes 50 secs.

Cached: fit: 15 secs; transform: 16 secs.

@SparkQA

SparkQA commented Mar 8, 2017

Test build #74216 has finished for PR 11601 at commit c67afc1.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@hhbyyh
Contributor Author

hhbyyh commented Mar 8, 2017

Thanks @MLnick for being the shepherd and providing consistent help with discussion and review. The performance test matches what I got in my local environment.

@MLnick
Contributor

MLnick commented Mar 16, 2017

jenkins retest this please

@MLnick
Contributor

MLnick commented Mar 16, 2017

Created SPARK-19969 to track doc and examples to be done for 2.2 release. I can help with this if you're tied up.

@SparkQA

SparkQA commented Mar 16, 2017

Test build #74651 has finished for PR 11601 at commit c67afc1.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@MLnick
Contributor

MLnick commented Mar 16, 2017

Merged to master. Thanks @hhbyyh and also everyone for reviews.
